Conser-vision Wildlife Image Classification¶

Camera traps are a tool used by conservationists to study and monitor a wide range of ecologies while limiting human interference. However, they also generate a vast amount of data that quickly exceeds the capacity of humans to sift through. That's where machine learning can help! Advances in computer vision can help automate tasks like species detection and identification, so that humans can spend more time learning from and protecting these ecologies.

This post walks through an initial approach to the Conservision Practice Area challenge on DrivenData, a practice competition where you identify animal species in a real-world dataset of wildlife images from Taï National Park in Côte d'Ivoire. The competition is designed to be accessible to participants at all levels, which makes it a great place to dive into the world of data science competitions and computer vision.

Models used:¶

1) YOLOv8 
2) ResNet50
3) Vanilla CNN

Pipeline:¶


Read data (csv files) --> descriptive statistics --> data split --> directories for each split --> YOLOv8

YOLOv8:  loading --> training --(validating)--> testing --> image processing --> interpretation

ResNet50: image processing --> dataloaders --> defining model --> training --(validation)--> testing --> interpretation

Imports¶

In [203]:
from IPython.display import clear_output

import os
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
# import ultralytics
# from ultralytics import YOLO
from torch.utils.data import Dataset
from torchvision import models
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping
from torchvision import transforms
import torchvision.transforms.functional as TF
# from captum.attr import IntegratedGradients
%matplotlib inline
In [162]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
    
clear_output()
In [213]:
# Global variables

IMAGE_SIZE = 224

# change as per your paths
HOME_DIR = "/kaggle/input/classification-dataset/"
OUT_DIR = '/kaggle/working/'
IMAGE_DIR = HOME_DIR + "train_features/"
YOLO_ROOT_DIR = HOME_DIR + "YOLO_dataset/"
TEST_DIR = HOME_DIR + 'test_features/'
YOLO_TEST_DIR = OUT_DIR + 'test/'

os.chdir(HOME_DIR)   # set the current working directory
print(os.getcwd())
/kaggle/input/classification-dataset
In [164]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda

Read data¶

In [165]:
train_features = pd.read_csv(HOME_DIR + "/train_features.csv", index_col = "id")
train_labels = pd.read_csv(HOME_DIR + "/train_labels.csv", index_col = "id")
In [166]:
train_features.head()
Out[166]:
filepath site
id
ZJ000000 train_features/ZJ000000.jpg S0120
ZJ000001 train_features/ZJ000001.jpg S0069
ZJ000002 train_features/ZJ000002.jpg S0009
ZJ000003 train_features/ZJ000003.jpg S0008
ZJ000004 train_features/ZJ000004.jpg S0036
In [167]:
train_labels.head()
Out[167]:
antelope_duiker bird blank civet_genet hog leopard monkey_prosimian rodent
id
ZJ000000 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
ZJ000001 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
ZJ000002 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
ZJ000003 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
ZJ000004 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
In [168]:
species_labels = sorted(train_labels.columns.unique())
species_labels
Out[168]:
['antelope_duiker',
 'bird',
 'blank',
 'civet_genet',
 'hog',
 'leopard',
 'monkey_prosimian',
 'rodent']
In [169]:
# Visualize some images

import matplotlib.image as mpimg

random_state = 42

# Grid with 8 positions, one for each label (7 species, plus blanks)
fig, axes = plt.subplots(nrows = 4, ncols = 2, figsize = (20, 20))

# iterate through each species
for species, ax in zip(species_labels, axes.flat):
    # get an image ID for this species
    img_id = (
        train_labels[train_labels.loc[:,species] == 1]
        .sample(1, random_state = random_state)
        .index[0]
    )
    
    # reads the filepath and returns a numpy array
    img = mpimg.imread(train_features.loc[img_id].filepath)
    
    # plot etc
    ax.imshow(img)
    ax.set_title(f"{img_id} | {species}")

Descriptive statistics¶

In [170]:
train_labels.sum().sort_values(ascending = False)
Out[170]:
monkey_prosimian    2492.0
antelope_duiker     2474.0
civet_genet         2423.0
leopard             2254.0
blank               2213.0
rodent              2013.0
bird                1641.0
hog                  978.0
dtype: float64
In [171]:
train_labels.sum().divide(train_labels.shape[0]).sort_values(ascending = False)
Out[171]:
monkey_prosimian    0.151140
antelope_duiker     0.150049
civet_genet         0.146955
leopard             0.136705
blank               0.134219
rodent              0.122089
bird                0.099527
hog                 0.059316
dtype: float64
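The classes are moderately imbalanced: hog accounts for only about 6% of the images, while the most common classes sit near 15%. One common remedy, not used in this notebook, is to weight the loss function by inverse class frequency. A minimal sketch using the counts above:

```python
# Per-class image counts taken from train_labels.sum() above
counts = {
    "monkey_prosimian": 2492, "antelope_duiker": 2474, "civet_genet": 2423,
    "leopard": 2254, "blank": 2213, "rodent": 2013, "bird": 1641, "hog": 978,
}

# Inverse-frequency weights: weight_c = N / (K * n_c), so rarer classes
# contribute more to the loss; all weights equal 1 when classes are balanced
n_total = sum(counts.values())
n_classes = len(counts)
weights = {cls: n_total / (n_classes * n) for cls, n in counts.items()}

for cls, w in sorted(weights.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{cls:18s} {w:.3f}")
```

These weights could then be passed to the loss, e.g. `nn.CrossEntropyLoss(weight=torch.tensor([...]))`, ordered to match `species_labels`. This is an optional extension, not part of the pipeline below.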

Data Split¶

From the original training data containing 16,488 images, 10% is allocated as test data, giving 1,649 images. From the remaining 14,839 images, 20% is allocated as validation data, giving 2,968 images and leaving 11,871 for training.

In [172]:
from sklearn.model_selection import train_test_split

frac = 1.0   # using all the train data

y = train_labels.sample(frac = frac, random_state = 1)
x = train_features.loc[y.index].filepath.to_frame()

# train-test split (10% of the train data)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, stratify = y, test_size = 0.10
)

# train-val split (20% of the remaining train data)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, stratify = y_train, test_size = 0.20
)
In [173]:
print(len(x_train), len(x_val), len(x_test))
11871 2968 1649
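Because both splits pass `stratify=` on the one-hot label matrix, each class should appear in near-identical proportions across train, validation, and test. The effect can be demonstrated on a small toy dataset (the four-class labels below are hypothetical stand-ins, not the competition data):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy one-hot labels with deliberately imbalanced class counts
classes = np.repeat(np.arange(4), [500, 400, 300, 100])
y = pd.get_dummies(classes).astype(float)
x = pd.DataFrame({"filepath": [f"img_{i}.jpg" for i in range(len(y))]})

# Stratifying on the one-hot matrix preserves class proportions in each split
x_tr, x_te, y_tr, y_te = train_test_split(x, y, stratify=y, test_size=0.10,
                                          random_state=1)

fractions = pd.DataFrame({"full": y.mean(), "train": y_tr.mean(),
                          "test": y_te.mean()})
print(fractions.round(3))
```

The same check can be run on the real `y_train`, `y_val`, and `y_test` by replacing the toy frames above.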

Creating subdirectories for each split (to hold that split's images)¶

In [212]:
# # Need to run this cell only once!

# import os
# import shutil

# HOME_DIR = '/kaggle/working'
# IMAGE_DIR = "/kaggle/input/classification-dataset/train_features/"

# os.makedirs(os.path.join(HOME_DIR, 'train'), exist_ok=True)
# os.makedirs(os.path.join(HOME_DIR, 'val'), exist_ok=True)
# os.makedirs(os.path.join(HOME_DIR, 'test'), exist_ok=True)

# # Function to copy an image with its label directory
# def copy_image(filepath, target_dir):
#     image_name = os.path.basename(filepath)
#     src_path = os.path.join(IMAGE_DIR, image_name)  # Update source path based on your dataset structure
#     target_path = os.path.join(target_dir, image_name)
#     shutil.copyfile(src_path, target_path)

# # Sample dataframes x_train, x_val, x_test are assumed to be defined earlier

# # Training Data
# for index, row in x_train.iterrows():
#     copy_image(row['filepath'], os.path.join(HOME_DIR, 'train'))

# # Validation Data
# for index, row in x_val.iterrows():
#     copy_image(row['filepath'], os.path.join(HOME_DIR, 'val'))

# # Test Data
# for index, row in x_test.iterrows():
#     copy_image(row['filepath'], os.path.join(HOME_DIR, 'test'))

# # Verify the directory structure and copied files
# print("Directory structure and files copied successfully!")
Directory structure and files copied successfully!

------------------- At this point, run the Segregate_data.py script to create the data structure expected by the YOLOv8 model! -------------------¶

Note: For the YOLO model, we create separate subfolders for each class. This is essential for training and validation; for testing, however, YOLO only needs a direct path to the test images, so we can reuse the test directory created during the train-val-test split.
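Segregate_data.py itself is not reproduced here. As a rough sketch of what it needs to do (the exact script may differ), the training and validation images have to be arranged into one subfolder per class, which is the layout the YOLOv8 classifier expects, e.g. `train/bird/`, `val/hog/`:

```python
import os
import shutil

def build_yolo_split(x_df, y_df, split_dir, image_dir):
    """Hypothetical sketch: copy each image into <split_dir>/<class_name>/
    so the tree matches YOLOv8's classification-dataset layout."""
    for img_id, row in x_df.iterrows():
        class_name = y_df.loc[img_id].idxmax()        # one-hot row -> class label
        class_dir = os.path.join(split_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        src = os.path.join(image_dir, os.path.basename(row["filepath"]))
        shutil.copyfile(src, os.path.join(class_dir, os.path.basename(src)))

# Usage against the splits defined earlier (paths are the notebook's globals):
# build_yolo_split(x_train, y_train, "YOLO_dataset/train", IMAGE_DIR)
# build_yolo_split(x_val,   y_val,   "YOLO_dataset/val",   IMAGE_DIR)
```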

In [175]:
# def save_model(model):
#     print(model, file = open('YOLO_summary.txt', "w"))  # yolo summary architecture 

YOLOv8 model¶

Loading and training¶

In [176]:
# Choices of YOLOv8 models:
# yolo_choices = ['yolov8n-cls.pt', 'yolov8s-cls.pt', 'yolov8m-cls.pt', 'yolov8l-cls.pt', 'yolov8x-cls.pt']

# yolo_model = YOLO(yolo_choices[1]).to(device)

# results = yolo_model.train(data = YOLO_ROOT_DIR, batch = 32,
#                       imgsz = IMAGE_SIZE, optimizer = 'Adam',
#                       lr0 = 0.001, lrf = 0.0001,
#                       plots = True, epochs = 5,
#                       device = device)

# save_model(yolo_model)

Testing YOLOv8¶

In [177]:
# Testing YOLO model

# BEST_MODEL_PATH = HOME_DIR + 'runs/classify/train/weights/best.pt'   # path requires change for every new training!

# yolo_test_model = YOLO(BEST_MODEL_PATH)

# yolo_test_results = yolo_test_model.predict(TEST_DIR)
In [178]:
# Image processing (for interpretations)

# def preprocess_image(image_path):
#     transform = transforms.Compose(
#         [
#             transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
#             transforms.ToTensor(),
#             transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
#         ]
#     )
#     image = Image.open(image_path).convert("RGB")
#     image_tensor = transform(image)
#     return image_tensor

Model Interpretation for YOLOv8 model¶

In [179]:
# YOLO Model Interpretation (NEEDS WORK)

# take a random image, process it, and produce interpretations.
# image = preprocess_image(TEST_DIR + '/ZJ001110.jpg')
# image_processed = image.unsqueeze(0)

# with torch.enable_grad():
#     results = yolo_model(image_processed)
#     probs = results[0].probs
#     target_index = torch.argmax(probs.data).item()

# def predict_wrapper(image_tensor):
#     results = yolo_model(image_tensor)
#     probs_tensor = results[0].probs.data.unsqueeze(0)
#     print(probs_tensor.shape)
#     return probs_tensor

# yolo_ig = IntegratedGradients(predict_wrapper)
# yolo_attributions = yolo_ig.attribute(image_processed, target = target_index, n_steps = 50)

Image processing¶

In [180]:
class ImagesDataset(Dataset):
    """Reads in an image, transforms pixel values, and serves
    a dictionary containing the image id, image tensors, and label.
    """

    def __init__(self, x_df, y_df = None):
        self.data = x_df
        self.label = y_df
        self.transform = transforms.Compose(
            [
                transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)
                ),
            ]
        )

    def __getitem__(self, index):
        image = Image.open(self.data.iloc[index]["filepath"]).convert("RGB")
        image = self.transform(image)
        image_id = self.data.index[index]
        # if we don't have labels (e.g. for test set) just return the image and image id
        if self.label is None:
            sample = {"image_id": image_id, "image": image}
        else:
            label = torch.tensor(self.label.iloc[index].values, 
                                 dtype = torch.float)
            sample = {"image_id": image_id, "image": image, "label": label}
        return sample

    def __len__(self):
        return len(self.data)

Dataloaders for torch models¶

In [181]:
from torch.utils.data import DataLoader

# Create a copy of x_train with absolute file paths
x_train_modified = x_train.copy()
x_train_modified['filepath'] = x_train_modified['filepath'].apply(lambda x: os.path.join(IMAGE_DIR, os.path.basename(x)))

# train dataloader (shuffled so each epoch sees the batches in a different order)
train_dataset = ImagesDataset(x_train_modified, y_train)
train_dataloader = DataLoader(train_dataset, batch_size = 16, shuffle = True)

# validation dataloader (relative filepaths resolve because the working directory is HOME_DIR)
val_dataset = ImagesDataset(x_val, y_val)
val_dataloader = DataLoader(val_dataset, batch_size = 16)

# test dataloader
test_dataset = ImagesDataset(x_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size = 16)

ResNet50 model¶

In [182]:
# Defining ResNet50 model

# resnet_model = models.resnet50(pretrained = True).to(device)
# resnet_model.fc = nn.Sequential(
#     nn.Linear(2048, 100),  # dense layer takes a 2048-dim input and outputs 100-dim
#     nn.ReLU(inplace = True),  # ReLU activation introduces non-linearity
#     nn.Dropout(0.1),  # common technique to mitigate overfitting
#     nn.Linear(
#         100, 8
#     ),  # final dense layer outputs 8-dim corresponding to our target classes
# ).to(device)
In [183]:
# criterion = nn.CrossEntropyLoss()   # loss function
# optimizer = optim.SGD(resnet_model.parameters(), lr = 0.0001)

Training and validating ResNet50¶

In [184]:
# num_epochs = 5
# tracking_loss = {}
# #early_stopping = EarlyStopping(monitor = "val_loss",
# #                               patience = 5, verbose = True)   # callback

# for epoch in range(1, num_epochs + 1):
#     print(f"Starting epoch {epoch}")
#     # TRAINING LOOP
#     epoch_loss = 0.0
#     # iterating through the dataloader batches. tqdm keeps track of progress.
#     with tqdm(enumerate(train_dataloader), total = len(train_dataloader)) as pbar:
#         for batch_n, batch in pbar:
#             optimizer.zero_grad()
#             batch['image'] = batch['image'].to(device)
#             batch['label'] = batch['label'].to(device)
#             outputs = resnet_model(batch["image"])
#             loss = criterion(outputs, batch["label"])
#             tracking_loss[(epoch, batch_n)] = float(loss.cpu())
#             epoch_loss += float(loss.cpu())
#             pbar.set_postfix(loss = f"{loss.item():.4f}")
#             loss.backward()
#             optimizer.step()

#     average_epoch_loss = epoch_loss / len(train_dataloader)
#     print(f"Epoch {epoch} Average training loss: {average_epoch_loss:.4f}")

#     # VALIDATION LOOP
#     resnet_model.eval()  # setting the model to evaluation mode
#     val_loss = 0.0

#     # detach from the computational graph
#     with torch.no_grad():
#         for batch in tqdm(val_dataloader, total = len(val_dataloader)):
#             batch["image"] = batch["image"].to(device)
#             batch['label'] = batch['label'].to(device)
#             logits = resnet_model.forward(batch["image"])
#             loss = criterion(logits, batch["label"])
#             val_loss += float(loss.cpu())

#     average_val_loss = val_loss / len(val_dataloader)
#     print(f"Epoch {epoch} Average Validation Loss: {average_val_loss:.4f}")
In [185]:
# learning curve

# tracking_loss = pd.Series(tracking_loss)

# plt.figure(figsize = (10, 5))
# tracking_loss.plot(alpha = 0.2, label = "loss")
# tracking_loss.rolling(center = True, min_periods = 1, window = 10).mean().plot(
#     label = "loss (moving avg)"
# )
# plt.xlabel("(Epoch, Batch)")
# plt.ylabel("Loss")
# plt.legend(loc = 0)
# plt.show()
In [186]:
# removing some cache
# with torch.no_grad():
#     torch.cuda.empty_cache()
In [187]:
# torch.save(resnet_model, "resnet_model.pth")   # save model
In [188]:
# loaded_model = torch.load("resnet_model.pth")  # load saved model

Testing ResNet50¶

In [189]:
# Testing ResNet50 model

# preds_collector = []
# test_loss = 0.0
# resnet_model.eval()

# with torch.no_grad():
#     for batch in tqdm(test_dataloader, total = len(test_dataloader)):
#         batch["image"] = batch["image"].to(device)
#         if "label" in batch:
#             batch["label"] = batch["label"].to(device)
#         # run the forward step
#         logits = resnet_model.forward(batch["image"])
#         #loss = criterion(logits, batch["label"])
#         #test_loss += float(loss.cpu())
#         # apply softmax so that model outputs are in range [0,1]
#         preds = nn.functional.softmax(logits, dim = 1)
#         # store this batch's predictions in df
#         # detaching from computational graph before converting to numpy arrays
#         preds_df = pd.DataFrame(
#             preds.detach().cpu().numpy(),
#             index = batch["image_id"],
#             columns = species_labels,
#         )
#         preds_collector.append(preds_df)

# submission_df = pd.concat(preds_collector)
# submission_df
In [190]:
# submission_df.to_csv("submission_df.csv")  

Model Interpretation for ResNet50 model¶

In [191]:
# # loading a sample image
# sample_image = preprocess_image(TEST_DIR + '/ZJ001110.jpg')
# sample_processed_img = sample_image.unsqueeze(0).to(device)
# resnet_model = resnet_model.to(device)

# # inference
# output = resnet_model(sample_processed_img)
# print('Output requires gradient? :', output.requires_grad)
# target_class_index = output.argmax(dim = 1).item()  # Assuming prediction is a class probability distribution

# # Captum interpretation
# ig = IntegratedGradients(resnet_model)
# attributions = ig.attribute(sample_processed_img, target = target_class_index, n_steps = 200)
# print(attributions.shape)
In [192]:
# from matplotlib.colors import LinearSegmentedColormap
# from captum.attr import visualization as viz

# default_cmap = LinearSegmentedColormap.from_list('custom blue',[(0, '#ffffff'),
#                                                   (0.25, '#000000'), (1, '#000000')], N = 256)

# _ = viz.visualize_image_attr(np.transpose(attributions.squeeze().cpu().detach().numpy(), (1,2,0)),
#                              np.transpose(sample_image.squeeze().cpu().detach().numpy(), (1,2,0)),
#                              method = 'heat_map', cmap = default_cmap, show_colorbar = True,
#                              sign = 'positive', outlier_perc = 1)
# plt.savefig('resnet50_captum2.jpg', dpi = 1000)   # save before plt.show(), which clears the figure
# plt.show()
In [193]:
# # NoiseTunnel attributions interpretation (might throw CUDA memory error, so pls free some GPU memory)

# from captum.attr import NoiseTunnel

# noise_tunnel = NoiseTunnel(ig)

# attributions_nt = noise_tunnel.attribute(sample_processed_img, nt_samples = 10,
#                                          nt_type = 'smoothgrad_sq',
#                                          target = target_class_index)

# _ = viz.visualize_image_attr_multiple(np.transpose(attributions_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
#                                       np.transpose(sample_image.squeeze().cpu().detach().numpy(), (1,2,0)),
#                                       ["original_image", "heat_map"],["all", "positive"],
#                                       cmap = default_cmap, show_colorbar = True)

# plt.savefig('resnet50_noisetunnel2.jpg', dpi = 1000)   # save before plt.show(), which clears the figure
# plt.show()

Vanilla CNN¶

In [194]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset



# Define the PyTorch model
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * (IMAGE_SIZE // 4) * (IMAGE_SIZE // 4), IMAGE_SIZE)
        self.fc2 = nn.Linear(IMAGE_SIZE, 8)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = torch.relu(x)  
        x = self.fc2(x)
        return x

# Create the model
model = ConvNet()

# Define the loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)
# model.to(device)

def train_model(model, dataloader, criterion, optimizer, num_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    
    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        correct = 0
        total = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch} - Training"):
            images, labels = batch['image'].to(device), batch['label'].to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Calculate training accuracy (labels are one-hot, so take the argmax)
            labels_indices = torch.argmax(labels, dim=1)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels_indices).sum().item()
            total += labels.size(0)

        train_loss = running_loss / len(dataloader)
        train_accuracy = correct / total

        val_loss, val_accuracy = evaluate_model(model, val_dataloader, criterion)

        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
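The `fc1` input size `64 * (IMAGE_SIZE // 4) * (IMAGE_SIZE // 4)` comes from the two stride-2 max-pools, each of which halves the spatial resolution (224 → 112 → 56). A quick shape check on the two conv blocks with a dummy batch:

```python
import torch
import torch.nn as nn

IMAGE_SIZE = 224  # matches the notebook's global

# The two conv blocks from ConvNet: each MaxPool2d halves H and W
features = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
)

x = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE)
out = features(x)
print(out.shape)                      # torch.Size([1, 64, 56, 56])
print(64 * (IMAGE_SIZE // 4) ** 2)    # 200704, the fc1 input size
```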
In [195]:
def evaluate_model(model, dataloader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images, labels = batch['image'].to(device), batch['label'].to(device)
            outputs = model(images)
            # Convert one-hot encoded labels to class indices
            labels_indices = torch.argmax(labels, dim=1)
            loss = criterion(outputs, labels_indices)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels_indices.size(0)
            correct += (predicted == labels_indices).sum().item()

    avg_loss = running_loss / len(dataloader.dataset)  # sum of per-batch mean losses over dataset size (a much smaller scale than the per-batch training loss)
    accuracy = correct / total
#     print(f'Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')
    return avg_loss, accuracy
In [196]:
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_model_with_metrics(model, dataloader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images, labels = batch['image'].to(device), batch['label'].to(device)
            outputs = model(images)
            # Convert one-hot encoded labels to class indices
            labels_indices = torch.argmax(labels, dim=1)
            loss = criterion(outputs, labels_indices)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels_indices.size(0)
            correct += (predicted == labels_indices).sum().item()
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels_indices.cpu().numpy())

    avg_loss = running_loss / len(dataloader.dataset)  # sum of per-batch mean losses over dataset size (a much smaller scale than the per-batch training loss)
    accuracy = correct / total
    precision = precision_score(all_labels, all_predictions, average='weighted')
    recall = recall_score(all_labels, all_predictions, average='weighted')
    f1 = f1_score(all_labels, all_predictions, average='weighted')

    return avg_loss, accuracy, precision, recall, f1
In [201]:
train_model(model, train_dataloader, criterion, optimizer, num_epochs=8)
Epoch 1 - Training: 100%|██████████| 742/742 [01:36<00:00,  7.72it/s]
Evaluating: 100%|██████████| 186/186 [00:17<00:00, 10.35it/s]
Epoch 1: Train Loss: 0.3639, Train Accuracy: 0.8769, Val Loss: 0.0406, Val Accuracy: 0.8036

In [202]:
#Evaluate Model with metrics
test_loss, test_accuracy, precision, recall, f1 = evaluate_model_with_metrics(model, test_dataloader, criterion)
print(f'Test - Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')
Evaluating: 100%|██████████| 104/104 [00:10<00:00, 10.09it/s]
Test - Loss: 0.0408, Accuracy: 0.8023, Precision: 0.7990, Recall: 0.8023, F1-Score: 0.7970

In [207]:
def compute_confusion_matrix(model, dataloader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Creating Confusion Matrix"):
            images, labels = batch['image'].to(device), batch['label'].to(device)
            outputs = model(images)
            # Convert one-hot encoded labels to class indices
            labels_indices = torch.argmax(labels, dim=1)
            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels_indices.cpu().numpy())

    return confusion_matrix(all_labels, all_predictions)

cm = compute_confusion_matrix(model, test_dataloader)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=species_labels, yticklabels=species_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
Creating Confusion Matrix: 100%|██████████| 104/104 [00:09<00:00, 10.59it/s]
In [ ]: